3934c1
@@ -20,6 +20,7 @@
 import java.io.IOException;
 import java.lang.management.ManagementFactory;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.Map;
 
 import org.apache.commons.logging.Log;
@@ -97,29 +98,33 @@
public void load(MapJoinTableContainer[] mapJoinTables,
     HybridHashTableConf nwayConf = null;
     long totalSize = 0;
     int biggest = 0;                                // position of the biggest small table
+    Map<Integer, Long> tableMemorySizes = null;
     if (useHybridGraceHashJoin && mapJoinTables.length > 2) {
       // Create a Conf for n-way HybridHashTableContainers
       nwayConf = new HybridHashTableConf();
 
       // Find the biggest small table; also calculate total data size of all small tables
-      long maxSize = 0; // the size of the biggest small table
+      long maxSize = Long.MIN_VALUE; // the size of the biggest small table
       for (int pos = 0; pos < mapJoinTables.length; pos++) {
         if (pos == desc.getPosBigTable()) {
           continue;
         }
-        totalSize += desc.getParentDataSizes().get(pos);
-        biggest = desc.getParentDataSizes().get(pos) > maxSize ? pos : biggest;
-        maxSize = desc.getParentDataSizes().get(pos) > maxSize ? desc.getParentDataSizes().get(pos)
-                                                               : maxSize;
+        long smallTableSize = desc.getParentDataSizes().get(pos);
+        totalSize += smallTableSize;
+        if (maxSize < smallTableSize) {
+          maxSize = smallTableSize;
+          biggest = pos;
+        }
       }
 
+      tableMemorySizes = divideHybridHashTableMemory(mapJoinTables, desc,
+          totalSize, noConditionalTaskThreshold);
       // Using biggest small table, calculate number of partitions to create for each small table
-      float percentage = (float) maxSize / totalSize;
-      long memory = (long) (noConditionalTaskThreshold * percentage);
+      long memory = tableMemorySizes.get(biggest);
       int numPartitions = 0;
       try {
         numPartitions = HybridHashTableContainer.calcNumPartitions(memory,
-            desc.getParentDataSizes().get(biggest),
+            maxSize,
             HiveConf.getIntVar(hconf, HiveConf.ConfVars.HIVEHYBRIDGRACEHASHJOINMINNUMPARTITIONS),
             HiveConf.getIntVar(hconf, HiveConf.ConfVars.HIVEHYBRIDGRACEHASHJOINMINWBSIZE),
             nwayConf);
@@ -169,9 +174,7 @@
public void load(MapJoinTableContainer[] mapJoinTables,
         long memory = 0;
         if (useHybridGraceHashJoin) {
           if (mapJoinTables.length > 2) {
-            // Allocate n-way join memory proportionally
-            float percentage = (float) desc.getParentDataSizes().get(pos) / totalSize;
-            memory = (long) (noConditionalTaskThreshold * percentage);
+            memory = tableMemorySizes.get(pos);
           } else {  // binary join
             memory = noConditionalTaskThreshold;
           }
@@ -196,6 +199,67 @@
public void load(MapJoinTableContainer[] mapJoinTables,
     }
   }
 
+  /**
+   * Splits the memory budget for n-way hybrid grace hash tables among the small tables.
+   *
+   * <p>Normally each small table gets a share of {@code totalHashTableMemory} proportional to
+   * its estimated size in {@code desc.getParentDataSizes()} relative to {@code totalSize}. If
+   * {@code totalSize} is not positive, or any small table's estimate is missing or not
+   * positive, the budget is instead split equally among {@code mapJoinTables.length - 1}
+   * small tables (at least one).
+   *
+   * @param mapJoinTables the join's table containers; only the array length is used here
+   * @param desc supplies the per-input size estimates and the big table position
+   * @param totalSize sum of the estimated sizes of all small tables
+   * @param totalHashTableMemory the memory budget to divide among the small tables
+   * @return map from small-table position to its memory allowance; the big table has no entry
+   */
+  private static Map<Integer, Long> divideHybridHashTableMemory(
+      MapJoinTableContainer[] mapJoinTables, MapJoinDesc desc,
+      long totalSize, long totalHashTableMemory) {
+    int smallTableCount = Math.max(mapJoinTables.length - 1, 1);
+    int bigTablePos = desc.getPosBigTable();
+    Map<Integer, Long> parentDataSizes = desc.getParentDataSizes();
+    Map<Integer, Long> tableMemorySizes = new HashMap<Integer, Long>();
+    // If any table has a bad size estimate, we need to fall back to sizing each table equally
+    boolean fallbackToEqualProportions = totalSize <= 0;
+
+    if (!fallbackToEqualProportions) {
+      for (Map.Entry<Integer, Long> tableSizeEntry : parentDataSizes.entrySet()) {
+        if (tableSizeEntry.getKey().intValue() == bigTablePos) {
+          continue;
+        }
+
+        Long tableSize = tableSizeEntry.getValue();
+        if (tableSize == null || tableSize.longValue() <= 0) {
+          fallbackToEqualProportions = true;
+          break;
+        }
+        // double rather than float: float's 24-bit mantissa rounds large byte counts
+        double fraction = (double) tableSize.longValue() / totalSize;
+        long tableMemory = (long) (totalHashTableMemory * fraction);
+        tableMemorySizes.put(tableSizeEntry.getKey(), tableMemory);
+      }
+    }
+
+    if (fallbackToEqualProportions) {
+      // Drop any proportional shares computed before the bad estimate was seen and
+      // just give each small table the same amount of memory.
+      tableMemorySizes.clear();
+      long equalPortion = totalHashTableMemory / smallTableCount;
+      for (Integer pos : parentDataSizes.keySet()) {
+        // Skip the big table but keep going: a 'break' here would leave every small table
+        // visited after the big table without an entry.
+        if (pos.intValue() == bigTablePos) {
+          continue;
+        }
+        tableMemorySizes.put(pos, equalPortion);
+      }
+    }
+
+    return tableMemorySizes;
+  }
+
   private String describeOi(String desc, ObjectInspector keyOi) {
     for (StructField field : ((StructObjectInspector)keyOi).getAllStructFieldRefs()) {
       ObjectInspector oi = field.getFieldObjectInspector();
